# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from sklearn import preprocessing
from sklearn.metrics import roc_curve, roc_auc_score, auc
import matplotlib.pyplot as plt


def prepare_data(data_filename, landslide_test_size=0.3, notlandslide_test_size=0.3, y_label="y_isLandslide_",
                 random_state=None):
    '''
        Read a csv file and split it, per class, into a training and a test
        DataFrame.

        Rows whose ``y_label`` value is 1 (landslide) and rows whose value is
        0 (not landslide) are shuffled and split independently, so each class
        contributes its own fraction of rows to the test set. Rows with any
        other label value are not included in either output.

        :param data_filename: string
            Path of the source csv file. The file is read with the "gb2312"
            encoding.

        :param landslide_test_size: float
            Fraction of the landslide rows that goes into the test set. The
            row count is truncated with ``int()``.

        :param notlandslide_test_size: float
            Fraction of the non-landslide rows that goes into the test set.
            The row count is truncated with ``int()``.

        :param y_label: str
            Name of the label column.

        :param random_state: int, RandomState instance or None
            Forwarded to ``sklearn.utils.shuffle`` for both classes. The
            default ``None`` keeps the original, unseeded behaviour; pass a
            value to get a reproducible split.

        :return: tuple
            ``(train_df, test_df)``. In each DataFrame the landslide rows come
            first, followed by the non-landslide rows; the original row index
            is preserved.

    '''

    # %% Read the data from csv
    raw_df = pd.read_csv(data_filename, encoding="gb2312")

    # %% Shuffle and split each class on its own (landslide first, then
    # non-landslide, which is also the row order of the returned frames)
    test_parts = []
    train_parts = []
    class_specs = ((1, landslide_test_size), (0, notlandslide_test_size))
    for label_value, test_size in class_specs:
        class_df = shuffle(raw_df[raw_df[y_label] == label_value],
                           random_state=random_state)
        test_row_num = int(class_df.shape[0] * test_size)
        test_parts.append(class_df.iloc[:test_row_num, :])   # used for prediction
        train_parts.append(class_df.iloc[test_row_num:, :])  # used for training

    # %% Merge the per-class parts.
    # DataFrame.append was removed in pandas 2.0; pd.concat gives the same
    # result and works on old and new pandas alike.
    test_df = pd.concat(test_parts)
    train_df = pd.concat(train_parts)

    # %% Return the data
    return train_df, test_df


def plot_roc_curve(year, y_test, y_score, pos_label=1,
                 classifier_name='logistic',
                 process_units='raster',
                 results_dir = None,
                 data_version = ''):
    '''
    Plot the ROC curve for one year and return the area under it.

    A new matplotlib figure is created on every call and is left open when
    the function returns. The picture is written to disk only when
    ``results_dir`` is given.

    :param year:
        Year the curve belongs to; shown in the plot title and used in the
        file name.

    :param y_test:
        True class labels, passed to ``roc_curve``.

    :param y_score:
        Class scores, passed to ``roc_curve``.

    :param pos_label:
        Label of the positive class, passed to ``roc_curve``.

    :param classifier_name: str
        Classifier name, used in the file name.

    :param process_units: str
        Processing unit name, used in the file name.

    :param results_dir: str or None
        Directory the png is saved to. When None nothing is saved.

    :param data_version: str
        Data version tag; when it is not the empty string it is appended to
        the file name.

    :return:
        The AUC value computed from the ROC curve.

    '''
    # %% ROC points and area; the thresholds returned by roc_curve are unused
    roc_points = roc_curve(y_test, y_score, pos_label=pos_label)
    fpr = np.array(roc_points[0])
    tpr = np.array(roc_points[1])
    auc_val = auc(fpr, tpr)

    # %% Draw the curve together with the diagonal reference line
    line_width = 1
    curve_label = 'ROC curve (area = %0.3f)' % auc_val

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=line_width, label=curve_label)
    plt.plot([0, 1], [0, 1], color='navy', lw=line_width, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve({0})'.format(year))
    plt.legend(loc="lower right")

    # %% Nothing to save without a target directory
    if results_dir is None:
        return auc_val

    # The data version is appended to the file name only when one was given.
    if data_version == '':
        path_pattern = "{0}/{1}-ROC-{2}-{3}.png"
        version_part = ()
    else:
        path_pattern = "{0}/{1}-ROC-{2}-{3}-{4}.png"
        version_part = (data_version,)
    roc_pic_path = path_pattern.format(results_dir, process_units, classifier_name, year,
                                       *version_part)
    plt.savefig(roc_pic_path, dpi=120)

    return auc_val


def split_to_xy(df,
                class_col_name="y_isLandslide",
                normalized=True):
    '''
        Convert a DataFrame into an ``(X, y)`` pair.

        ``y`` is selected by name (``class_col_name``). ``X`` is selected
        purely by position: every column except the first and the last one.
        The layout therefore has to be: first column = ID column (excluded
        from X), last column = label column.

        :param df: DataFrame
            Input pandas DataFrame.

        :param class_col_name: str
            Name of the label column.

        :param normalized: bool
            When True, X is standardized with ``preprocessing.scale``;
            otherwise the raw feature values are returned.

        :return: tuple
            ``(X, y)``, both ndarrays.

    '''

    # %% Split into the X and y parts
    y = df[class_col_name].values  # label column, looked up by name

    # Slice the feature columns BEFORE converting to an ndarray. Converting
    # the whole frame first (df.values[:, 1:-1]) forces one common dtype over
    # all columns, so a non-numeric ID column would turn X into an object
    # array even though every feature is numeric.
    X_raw = df.iloc[:, 1:-1].values

    # %% Optional normalization
    if normalized:
        X = preprocessing.scale(X_raw)
    else:
        X = X_raw

    return X, y

# if __name__=="__main__":
#     df = pd.read_csv('RasterUnitsV4_2003_2003.csv', encoding='gb2312')
#     print(df.columns)
#     split_to_xy(df,normalized=False)
